import cv2 as cv
import numpy as np
import scipy
import scipy.ndimage
import scipy.special

from .base import BaseEvaluator

# Python version of Naturalness Image Quality Evaluator (NIQE)
# Original source code: https://github.com/dsoellinger/blind_image_quality_toolbox/tree/master/%2Bniqe

def create_evaluator():
  """Factory entry point: build and return a new ``NIQEEvaluator``."""
  evaluator = NIQEEvaluator()
  return evaluator


class NIQEEvaluator(BaseEvaluator):
  """No-reference image quality evaluator based on NIQE.

  Features are computed on non-overlapping blocks at ``scalenum`` scales and
  summarized by their mean and covariance; the score is a Mahalanobis-style
  distance between that summary and the hard-coded reference model
  (``model_param_mu`` / ``model_param_cov``).  A lower score means the image
  statistics are closer to the reference model.
  """

  # Candidate AGGD shape parameters and the matching ratio
  # Gamma(2/a)^2 / (Gamma(1/a) * Gamma(3/a)) searched by _estimate_aggd_param.
  # They do not depend on the input, so build the table once per class
  # instead of on every one of the many per-block calls.
  _AGGD_GAM = np.arange(start=0.2, stop=10.001, step=0.001)
  _AGGD_R_GAM = ((scipy.special.gamma(2.0 / _AGGD_GAM) ** 2)
                 / (scipy.special.gamma(1.0 / _AGGD_GAM) * scipy.special.gamma(3.0 / _AGGD_GAM)))

  def __init__(self):
    super().__init__()

    self._init_model_params()

    # blocksizerow/blocksizecol: block size in pixels at the finest scale.
    # featnum: features per block per scale produced by _compute_feature
    #   (informational; not read by this class).
    # scalenum: number of scales (the image is halved between scales).
    self.config = {
      'blocksizerow': 96,
      'blocksizecol': 96,
      'featnum': 18,
      'scalenum': 2
    }

  def evaluate(self, output_image, truth_image):
    """Return the NIQE score of ``output_image``.

    NIQE is a no-reference metric: ``truth_image`` is accepted for signature
    compatibility and ignored.
    """
    return self._compute_quality(output_image)

  def _compute_quality(self, image):
    """Compute the NIQE score of a single image.

    Args:
      image: 2-D array, or 3-D array with channels on the last axis.  A
        3-channel image is converted to gray with ``cv.COLOR_RGB2GRAY``
        (i.e. it is treated as RGB); for any other channel count only
        channel 0 is used.

    Returns:
      The score as a NumPy float; lower means closer to the reference model.

    Raises:
      ValueError: if the image is smaller than one
        ``blocksizerow`` x ``blocksizecol`` block.
    """
    if image.ndim == 3:
      if image.shape[2] == 3:
        # cvtColor only accepts 8U/16U/32F input; promote anything else to
        # float32 instead of letting OpenCV reject it.
        if image.dtype not in (np.uint8, np.uint16, np.float32):
          image = image.astype(np.float32)
        # cvtColor already returns a 2-D array, so no channel slicing here.
        image = cv.cvtColor(image, cv.COLOR_RGB2GRAY)
      else:
        image = image[:, :, 0]

    image = image.astype(np.float64)

    blocksizerow = self.config['blocksizerow']
    blocksizecol = self.config['blocksizecol']
    scalenum = self.config['scalenum']

    # Only whole blocks are analysed; trailing rows/columns are cropped off.
    block_rownum = int(image.shape[0] // blocksizerow)
    block_colnum = int(image.shape[1] // blocksizecol)
    if block_rownum == 0 or block_colnum == 0:
      raise ValueError(
        'NIQE needs an image of at least %dx%d pixels, got %dx%d'
        % (blocksizerow, blocksizecol, image.shape[0], image.shape[1]))
    image = image[:block_rownum * blocksizerow, :block_colnum * blocksizecol]

    # Normalized 7x7 Gaussian used for the local mean / deviation maps.
    window = self._gaussian_filter_2d(shape=(7, 7), sigma=(7.0 / 6.0))
    window = window / np.sum(window)

    feat = []
    for itr_scale in range(1, scalenum + 1):
      # Locally normalized map: (image - local mean) / (local std + 1).
      mu = scipy.ndimage.correlate(image, window, mode='nearest')
      mu_sq = mu * mu
      sigma = np.sqrt(np.abs(scipy.ndimage.correlate(image * image, window, mode='nearest') - mu_sq))
      structdis = (image - mu) / (sigma + 1)

      # The image is halved between scales, so the blocks shrink by the same
      # factor and the block grid (block_rownum x block_colnum) is unchanged.
      # For scales 1 and 2 this equals blocksize / itr_scale.
      scale_factor = 2 ** (itr_scale - 1)
      itr_blocksizerow = int(blocksizerow // scale_factor)
      itr_blocksizecol = int(blocksizecol // scale_factor)

      feat_scale = []
      for row_index in range(block_rownum):
        row_start = row_index * itr_blocksizerow
        for col_index in range(block_colnum):
          col_start = col_index * itr_blocksizecol
          structdis_block = structdis[row_start:row_start + itr_blocksizerow,
                                      col_start:col_start + itr_blocksizecol]
          feat_scale.append(self._compute_feature(structdis_block))
      feat.append(feat_scale)

      if itr_scale < scalenum:
        # Prepare the next (coarser) scale; after the last scale it is unused.
        image = cv.resize(image, None, fx=0.5, fy=0.5)

    feat = np.array(feat)  # [scalenum, num_blocks, featnum]
    feat = feat.transpose([1, 0, 2])  # [num_blocks, scalenum, featnum]
    # [num_blocks, scalenum*featnum], finest scale first for each block.
    distparam = feat.reshape([feat.shape[0], feat.shape[1] * feat.shape[2]])

    mu_distparam = np.nanmean(distparam, axis=0, keepdims=True)
    cov_distparam = np.cov(distparam.transpose())

    # Distance between the image's feature Gaussian and the reference model,
    # using the average of both covariances.
    pinv_matrix = (self.model_param_cov + cov_distparam) / 2.0
    invcov_param = np.linalg.pinv(pinv_matrix, rcond=1e-4)
    diff = self.model_param_mu - mu_distparam  # shape (1, scalenum*featnum)
    quality = np.sqrt(np.matmul(np.matmul(diff, invcov_param), diff.transpose()))
    return quality.flatten()[0]

  def _compute_feature(self, structdis):
    """Return the 18 NIQE features of one normalized block.

    Two come from an AGGD fit to the block itself (shape, mean of the left and
    right scales); four more (shape, mean, left scale, right scale) come from
    each of four products of the block with a circularly shifted copy
    (horizontal, vertical, main diagonal, anti-diagonal).
    """
    gamma = scipy.special.gamma

    flat = structdis.flatten()
    feat = []

    alpha, betal, betar = self._estimate_aggd_param(flat)
    feat.extend([alpha, ((betal + betar) / 2.0)])

    # (row, col) shifts passed to np.roll.
    shifts = [[0, 1], [1, 0], [1, 1], [1, -1]]
    for shift in shifts:
      shifted_structdis = np.roll(structdis, shift=shift, axis=[0, 1])
      pair = flat * shifted_structdis.flatten()
      alpha, betal, betar = self._estimate_aggd_param(pair)
      meanparam = (betar - betal) * (gamma(2.0 / alpha) / gamma(1.0 / alpha))

      feat.extend([alpha, meanparam, betal, betar])

    return feat

  def _estimate_aggd_param(self, vec):
    """Fit an asymmetric generalized Gaussian distribution (AGGD) to ``vec``.

    The shape is chosen by nearest match of a moment ratio against the
    precomputed ``_AGGD_R_GAM`` table.

    Args:
      vec: 1-D array of samples.

    Returns:
      ``(alpha, betal, betar)``: shape parameter and left/right scales.
    """
    gamma = scipy.special.gamma

    negatives = vec[vec < 0]
    positives = vec[vec > 0]
    # A side with no samples has no spread estimate; fall back to a tiny
    # positive value without triggering numpy's "mean of empty slice" warning.
    leftstd = np.sqrt(np.mean(negatives ** 2)) if negatives.size else 1e-8
    rightstd = np.sqrt(np.mean(positives ** 2)) if positives.size else 1e-8

    gammahat = leftstd / rightstd
    rhat = (np.mean(np.abs(vec)) ** 2) / (np.mean(vec ** 2))
    rhatnorm = (rhat * ((gammahat ** 3) + 1) * (gammahat + 1)) / (((gammahat ** 2) + 1) ** 2)

    array_position = np.argmin((self._AGGD_R_GAM - rhatnorm) ** 2)
    alpha = self._AGGD_GAM[array_position]

    scale = np.sqrt(gamma(1.0 / alpha) / gamma(3.0 / alpha))
    betal = leftstd * scale
    betar = rightstd * scale

    return alpha, betal, betar

  def _gaussian_filter_2d(self, shape=(5, 5), sigma=1.0):
    """Return a normalized 2-D Gaussian kernel of the given ``shape``.

    Entries below ``eps * max`` are zeroed before normalization (presumably
    mirroring MATLAB's ``fspecial('gaussian')``, which the ported toolbox
    used).
    """
    m = (shape[0] - 1.0) / 2.0
    n = (shape[1] - 1.0) / 2.0
    y, x = np.ogrid[-m:(m + 1), -n:(n + 1)]

    h = np.exp(-((x ** 2) + (y ** 2)) / (2.0 * (sigma ** 2)))
    h[h < (np.finfo(h.dtype).eps * np.max(h))] = 0

    h_sum = np.sum(h)
    if h_sum != 0:
      h = h / h_sum

    return h

  def _init_model_params(self):
    """Set the hard-coded reference model used by ``_compute_quality``.

    ``model_param_mu`` (36,) and ``model_param_cov`` (36, 36) are the mean and
    covariance of the 36-D per-block feature vector (18 features x 2 scales,
    finest scale first).  Values presumably come from the original toolbox's
    model; one covariance row per source line.
    """
    self.model_param_cov = np.array([
      [0.45348,0.09610,0.08276,0.01533,0.03672,0.06329,0.08650,-0.00406,0.05017,0.05241,0.08688,-0.01251,0.05819,0.04526,0.08555,-0.00979,0.05768,0.04422,0.29289,0.06635,0.05233,0.00378,0.03336,0.04209,0.05250,-0.00447,0.03779,0.03800,0.05121,-0.01076,0.04313,0.03095,0.04956,-0.00685,0.04139,0.03192],
      [0.09610,0.03711,0.02155,0.00241,0.01436,0.01934,0.02188,-0.00102,0.01662,0.01751,0.02138,-0.00518,0.01918,0.01474,0.02096,-0.00493,0.01918,0.01436,0.04406,0.02391,0.01243,-0.00054,0.01144,0.01234,0.01216,-0.00200,0.01219,0.01152,0.01098,-0.00372,0.01318,0.00955,0.01051,-0.00323,0.01288,0.00959],
      [0.08276,0.02155,0.01771,0.00299,0.00856,0.01375,0.01781,0.00001,0.01063,0.01180,0.01718,-0.00218,0.01223,0.00994,0.01691,-0.00204,0.01224,0.00962,0.04987,0.01513,0.01122,0.00035,0.00779,0.00930,0.01113,-0.00043,0.00821,0.00875,0.01003,-0.00187,0.00910,0.00702,0.00967,-0.00164,0.00899,0.00694],
      [0.01533,0.00241,0.00299,0.00394,-0.00050,0.00310,0.00412,-0.00154,0.00241,0.00110,0.00352,0.00005,0.00175,0.00141,0.00356,-0.00008,0.00181,0.00131,0.01080,0.00307,0.00208,0.00259,0.00026,0.00261,0.00319,-0.00167,0.00251,0.00126,0.00273,-0.00046,0.00199,0.00142,0.00268,-0.00049,0.00200,0.00140],
      [0.03672,0.01436,0.00856,-0.00050,0.00677,0.00755,0.00814,0.00029,0.00653,0.00749,0.00780,-0.00202,0.00782,0.00600,0.00761,-0.00211,0.00786,0.00581,0.01344,0.00877,0.00475,-0.00138,0.00505,0.00440,0.00415,-0.00008,0.00433,0.00467,0.00358,-0.00118,0.00484,0.00369,0.00343,-0.00119,0.00483,0.00363],
      [0.06329,0.01934,0.01375,0.00310,0.00755,0.01233,0.01430,-0.00131,0.01034,0.01002,0.01340,-0.00251,0.01121,0.00864,0.01323,-0.00274,0.01138,0.00825,0.02960,0.01291,0.00798,0.00065,0.00629,0.00776,0.00819,-0.00164,0.00752,0.00673,0.00714,-0.00191,0.00770,0.00568,0.00688,-0.00189,0.00767,0.00561],
      [0.08650,0.02188,0.01781,0.00412,0.00814,0.01430,0.01920,-0.00110,0.01160,0.01187,0.01795,-0.00241,0.01267,0.01010,0.01770,-0.00225,0.01271,0.00974,0.05218,0.01521,0.01129,0.00149,0.00722,0.00968,0.01191,-0.00153,0.00892,0.00853,0.01053,-0.00206,0.00932,0.00700,0.01017,-0.00164,0.00911,0.00704],
      [-0.00406,-0.00102,0.00001,-0.00154,0.00029,-0.00131,-0.00110,0.00384,-0.00260,0.00065,-0.00061,0.00136,-0.00127,-0.00015,-0.00064,0.00152,-0.00144,0.00003,0.00095,0.00167,0.00091,-0.00150,0.00155,0.00029,-0.00003,0.00258,-0.00068,0.00165,0.00022,0.00038,0.00046,0.00086,0.00016,0.00020,0.00055,0.00072],
      [0.05017,0.01662,0.01063,0.00241,0.00653,0.01034,0.01160,-0.00260,0.00967,0.00805,0.01050,-0.00283,0.00983,0.00709,0.01035,-0.00318,0.01007,0.00666,0.02047,0.00974,0.00568,0.00061,0.00457,0.00579,0.00600,-0.00228,0.00619,0.00461,0.00515,-0.00171,0.00588,0.00408,0.00499,-0.00163,0.00579,0.00412],
      [0.05241,0.01751,0.01180,0.00110,0.00749,0.01002,0.01187,0.00065,0.00805,0.00956,0.01128,-0.00191,0.00959,0.00772,0.01104,-0.00188,0.00957,0.00751,0.02555,0.01213,0.00726,-0.00062,0.00648,0.00669,0.00691,-0.00018,0.00622,0.00672,0.00605,-0.00152,0.00690,0.00534,0.00579,-0.00149,0.00687,0.00525],
      [0.08688,0.02138,0.01718,0.00352,0.00780,0.01340,0.01795,-0.00061,0.01050,0.01128,0.01912,-0.00277,0.01260,0.00997,0.01856,-0.00210,0.01237,0.00971,0.05637,0.01476,0.01116,0.00125,0.00709,0.00934,0.01144,-0.00094,0.00827,0.00837,0.01130,-0.00251,0.00955,0.00684,0.01086,-0.00153,0.00907,0.00703],
      [-0.01251,-0.00518,-0.00218,0.00005,-0.00202,-0.00251,-0.00241,0.00136,-0.00283,-0.00191,-0.00277,0.00305,-0.00366,-0.00125,-0.00229,0.00062,-0.00259,-0.00195,-0.00061,-0.00133,-0.00047,0.00019,-0.00066,-0.00066,-0.00039,0.00105,-0.00117,-0.00021,-0.00067,0.00180,-0.00152,0.00007,-0.00033,-0.00055,-0.00035,-0.00083],
      [0.05819,0.01918,0.01223,0.00175,0.00782,0.01121,0.01267,-0.00127,0.00983,0.00959,0.01260,-0.00366,0.01140,0.00799,0.01210,-0.00269,0.01094,0.00801,0.02497,0.01181,0.00686,-0.00005,0.00591,0.00668,0.00673,-0.00146,0.00670,0.00592,0.00630,-0.00249,0.00731,0.00481,0.00594,-0.00133,0.00668,0.00522],
      [0.04526,0.01474,0.00994,0.00141,0.00600,0.00864,0.01010,-0.00015,0.00709,0.00772,0.00997,-0.00125,0.00799,0.00678,0.00989,-0.00212,0.00845,0.00622,0.02340,0.01013,0.00610,-0.00007,0.00519,0.00577,0.00604,-0.00056,0.00541,0.00549,0.00537,-0.00104,0.00575,0.00464,0.00526,-0.00171,0.00608,0.00427],
      [0.08555,0.02096,0.01691,0.00356,0.00761,0.01323,0.01770,-0.00064,0.01035,0.01104,0.01856,-0.00229,0.01210,0.00989,0.01894,-0.00265,0.01263,0.00954,0.05556,0.01447,0.01097,0.00131,0.00690,0.00917,0.01129,-0.00085,0.00803,0.00823,0.01107,-0.00190,0.00906,0.00690,0.01109,-0.00217,0.00931,0.00670],
      [-0.00979,-0.00493,-0.00204,-0.00008,-0.00211,-0.00274,-0.00225,0.00152,-0.00318,-0.00188,-0.00210,0.00062,-0.00269,-0.00212,-0.00265,0.00349,-0.00415,-0.00111,-0.00050,-0.00135,-0.00032,0.00043,-0.00079,-0.00054,-0.00028,0.00082,-0.00101,-0.00033,-0.00051,-0.00023,-0.00059,-0.00073,-0.00100,0.00195,-0.00170,-0.00003],
      [0.05768,0.01918,0.01224,0.00181,0.00786,0.01138,0.01271,-0.00144,0.01007,0.00957,0.01237,-0.00259,0.01094,0.00845,0.01263,-0.00415,0.01188,0.00767,0.02493,0.01166,0.00674,-0.00015,0.00589,0.00654,0.00663,-0.00136,0.00653,0.00590,0.00619,-0.00152,0.00676,0.00515,0.00627,-0.00255,0.00731,0.00474],
      [0.04422,0.01436,0.00962,0.00131,0.00581,0.00825,0.00974,0.00003,0.00666,0.00751,0.00971,-0.00195,0.00801,0.00622,0.00954,-0.00111,0.00767,0.00639,0.02192,0.00994,0.00592,0.00008,0.00496,0.00568,0.00585,-0.00059,0.00531,0.00529,0.00520,-0.00162,0.00587,0.00424,0.00495,-0.00091,0.00551,0.00445],
      [0.29289,0.04406,0.04987,0.01080,0.01344,0.02960,0.05218,0.00095,0.02047,0.02555,0.05637,-0.00061,0.02497,0.02340,0.05556,-0.00050,0.02493,0.02192,0.80888,0.07393,0.09136,0.00550,0.04098,0.05445,0.09418,0.00484,0.04118,0.05456,0.10108,-0.00172,0.05248,0.04773,0.09926,-0.00566,0.05574,0.04302],
      [0.06635,0.02391,0.01513,0.00307,0.00877,0.01291,0.01521,0.00167,0.00974,0.01213,0.01476,-0.00133,0.01181,0.01013,0.01447,-0.00135,0.01166,0.00994,0.07393,0.02515,0.01508,0.00062,0.01143,0.01343,0.01496,0.00017,0.01160,0.01305,0.01447,-0.00221,0.01347,0.01104,0.01389,-0.00257,0.01356,0.01061],
      [0.05233,0.01243,0.01122,0.00208,0.00475,0.00798,0.01129,0.00091,0.00568,0.00726,0.01116,-0.00047,0.00686,0.00610,0.01097,-0.00032,0.00674,0.00592,0.09136,0.01508,0.01607,0.00071,0.00840,0.01067,0.01506,0.00109,0.00793,0.01023,0.01440,-0.00041,0.00934,0.00845,0.01418,-0.00085,0.00968,0.00797],
      [0.00378,-0.00054,0.00035,0.00259,-0.00138,0.00065,0.00149,-0.00150,0.00061,-0.00062,0.00125,0.00019,-0.00005,-0.00007,0.00131,0.00043,-0.00015,0.00008,0.00550,0.00062,0.00071,0.00368,-0.00164,0.00154,0.00250,-0.00178,0.00142,-0.00006,0.00175,-0.00005,0.00053,0.00033,0.00187,0.00041,0.00031,0.00062],
      [0.03336,0.01144,0.00779,0.00026,0.00505,0.00629,0.00722,0.00155,0.00457,0.00648,0.00709,-0.00066,0.00591,0.00519,0.00690,-0.00079,0.00589,0.00496,0.04098,0.01143,0.00840,-0.00164,0.00714,0.00648,0.00718,0.00115,0.00532,0.00716,0.00699,-0.00085,0.00667,0.00576,0.00663,-0.00150,0.00703,0.00518],
      [0.04209,0.01234,0.00930,0.00261,0.00440,0.00776,0.00968,0.00029,0.00579,0.00669,0.00934,-0.00066,0.00668,0.00577,0.00917,-0.00054,0.00654,0.00568,0.05445,0.01343,0.01067,0.00154,0.00648,0.00898,0.01080,-0.00033,0.00742,0.00803,0.00984,-0.00101,0.00808,0.00680,0.00959,-0.00119,0.00819,0.00656],
      [0.05250,0.01216,0.01113,0.00319,0.00415,0.00819,0.01191,-0.00003,0.00600,0.00691,0.01144,-0.00039,0.00673,0.00604,0.01129,-0.00028,0.00663,0.00585,0.09418,0.01496,0.01506,0.00250,0.00718,0.01080,0.01779,-0.00029,0.00910,0.01049,0.01511,-0.00026,0.00932,0.00863,0.01483,-0.00095,0.00980,0.00794],
      [-0.00447,-0.00200,-0.00043,-0.00167,-0.00008,-0.00164,-0.00153,0.00258,-0.00228,-0.00018,-0.00094,0.00105,-0.00146,-0.00056,-0.00085,0.00082,-0.00136,-0.00059,0.00484,0.00017,0.00109,-0.00178,0.00115,-0.00033,-0.00029,0.00324,-0.00164,0.00122,0.00056,0.00085,-0.00023,0.00057,0.00046,0.00033,0.00004,0.00021],
      [0.03779,0.01219,0.00821,0.00251,0.00433,0.00752,0.00892,-0.00068,0.00619,0.00622,0.00827,-0.00117,0.00670,0.00541,0.00803,-0.00101,0.00653,0.00531,0.04118,0.01160,0.00793,0.00142,0.00532,0.00742,0.00910,-0.00164,0.00717,0.00646,0.00761,-0.00133,0.00704,0.00554,0.00738,-0.00138,0.00704,0.00541],
      [0.03800,0.01152,0.00875,0.00126,0.00467,0.00673,0.00853,0.00165,0.00461,0.00672,0.00837,-0.00021,0.00592,0.00549,0.00823,-0.00033,0.00590,0.00529,0.05456,0.01305,0.01023,-0.00006,0.00716,0.00803,0.01049,0.00122,0.00646,0.00866,0.00945,-0.00051,0.00761,0.00694,0.00905,-0.00126,0.00804,0.00625],
      [0.05121,0.01098,0.01003,0.00273,0.00358,0.00714,0.01053,0.00022,0.00515,0.00605,0.01130,-0.00067,0.00630,0.00537,0.01107,-0.00051,0.00619,0.00520,0.10108,0.01447,0.01440,0.00175,0.00699,0.00984,0.01511,0.00056,0.00761,0.00945,0.01738,-0.00102,0.01000,0.00851,0.01611,-0.00116,0.00986,0.00776],
      [-0.01076,-0.00372,-0.00187,-0.00046,-0.00118,-0.00191,-0.00206,0.00038,-0.00171,-0.00152,-0.00251,0.00180,-0.00249,-0.00104,-0.00190,-0.00023,-0.00152,-0.00162,-0.00172,-0.00221,-0.00041,-0.00005,-0.00085,-0.00101,-0.00026,0.00085,-0.00133,-0.00051,-0.00102,0.00247,-0.00224,-0.00008,-0.00040,-0.00057,-0.00070,-0.00115],
      [0.04313,0.01318,0.00910,0.00199,0.00484,0.00770,0.00932,0.00046,0.00588,0.00690,0.00955,-0.00152,0.00731,0.00575,0.00906,-0.00059,0.00676,0.00587,0.05248,0.01347,0.00934,0.00053,0.00667,0.00808,0.00932,-0.00023,0.00704,0.00761,0.01000,-0.00224,0.00876,0.00639,0.00907,-0.00114,0.00804,0.00645],
      [0.03095,0.00955,0.00702,0.00142,0.00369,0.00568,0.00700,0.00086,0.00408,0.00534,0.00684,0.00007,0.00481,0.00464,0.00690,-0.00073,0.00515,0.00424,0.04773,0.01104,0.00845,0.00033,0.00576,0.00680,0.00863,0.00057,0.00554,0.00694,0.00851,-0.00008,0.00639,0.00614,0.00808,-0.00160,0.00711,0.00513],
      [0.04956,0.01051,0.00967,0.00268,0.00343,0.00688,0.01017,0.00016,0.00499,0.00579,0.01086,-0.00033,0.00594,0.00526,0.01109,-0.00100,0.00627,0.00495,0.09926,0.01389,0.01418,0.00187,0.00663,0.00959,0.01483,0.00046,0.00738,0.00905,0.01611,-0.00040,0.00907,0.00808,0.01721,-0.00170,0.01021,0.00769],
      [-0.00685,-0.00323,-0.00164,-0.00049,-0.00119,-0.00189,-0.00164,0.00020,-0.00163,-0.00149,-0.00153,-0.00055,-0.00133,-0.00171,-0.00217,0.00195,-0.00255,-0.00091,-0.00566,-0.00257,-0.00085,0.00041,-0.00150,-0.00119,-0.00095,0.00033,-0.00138,-0.00126,-0.00116,-0.00057,-0.00114,-0.00160,-0.00170,0.00279,-0.00283,-0.00036],
      [0.04139,0.01288,0.00899,0.00200,0.00483,0.00767,0.00911,0.00055,0.00579,0.00687,0.00907,-0.00035,0.00668,0.00608,0.00931,-0.00170,0.00731,0.00551,0.05574,0.01356,0.00968,0.00031,0.00703,0.00819,0.00980,0.00004,0.00704,0.00804,0.00986,-0.00070,0.00804,0.00711,0.01021,-0.00283,0.00933,0.00614],
      [0.03192,0.00959,0.00694,0.00140,0.00363,0.00561,0.00704,0.00072,0.00412,0.00525,0.00703,-0.00083,0.00522,0.00427,0.00670,-0.00003,0.00474,0.00445,0.04302,0.01061,0.00797,0.00062,0.00518,0.00656,0.00794,0.00021,0.00541,0.00625,0.00776,-0.00115,0.00645,0.00513,0.00769,-0.00036,0.00614,0.00542],
    ])

    self.model_param_mu = np.array([2.60131,0.90570,0.81205,0.09043,0.13873,0.20603,0.81897,0.06246,0.15333,0.19591,0.82647,-0.02553,0.18857,0.16578,0.82429,-0.02536,0.18724,0.16505,2.96949,0.96123,0.84935,0.08238,0.16132,0.22492,0.85895,0.05508,0.17531,0.21713,0.87208,-0.03222,0.21549,0.18821,0.86940,-0.03233,0.21474,0.18678])

